import matplotlib.pyplot as plt
import numpy as np

# x positions 1..4; only referenced by plotting code that is currently disabled.
Indices = list(range(1, 5))

# One value per setting (four settings) for each algorithm; these three lists
# feed the grouped bar chart further down.
GD = [
    7528,
    4669,
    5071,
    1634,
]
SGD = [
    2372,
    2034,
    3443,
    1653,
]
NoisyGD = [
    7385,
    4547,
    5032,
    1641,
]

# NOTE(review): the two lists below hold five entries while the lists above
# hold four, and no live code in the visible file reads them.
CVSGD = [
    3.0917997273102337e-10,
    2.845301097042596e-10,
    2.0790134185726798e-10,
    2.3279583815162888e-10,
    2.4676310984666867e-10,
]
ANTICVSGD = [
    3.732160466698911e-10,
    2.970398358205874e-10,
    2.5251350990720937e-10,
    1.673754170258415e-10,
    3.18640200863529e-10,
]
#AccGradientVarianceRatio4 = [2228/2283, 2338/2427, 2387/2462, 10316/11035, 2425/2543]
#GeneralizationGapRatio8 = [1.6298/1.6407, 1.2680/1.2792, 1.0600/1.0690, 1.3886/1.3494, 1.6543/1.6872]
#AccGradientVarianceRatio8 = [4200/4269, 4663/4862, 3489/3613, 3817/4107, 4576/4881]
#GeneralizationGapRatio10 = [0.8525/0.8506, 0.8455/0.8515, 0.6763/0.6851, 0.7181/0.7222, 0.7534/0.7716]
#AccGradientVarianceRatio10 = [2284/2283, 2398/2427, 2441/2462, 2021/2059, 2494/2543]
# Disabled code: the triple-quoted string below is a bare expression statement,
# so Python evaluates and discards it and none of these plot calls run.
# The snippet reads GeneralizationGapRatio*/AccGradientVarianceRatio* lists that
# have no live definition in the visible part of this file (some appear only in
# commented-out lines), so it cannot simply be un-quoted as-is.
'''
plt.plot(Indices, GeneralizationGapRatio2, color='b', linestyle='solid', label='Gene Gap Ratio, bs=2')
plt.plot(Indices, GeneralizationGapRatio4, color='r', linestyle='solid', label='Gene Gap Ratio, bs=4')
#plt.plot(Indices, GeneralizationGapRatio8, color='k', linestyle='solid', label='Gene Gap Ratio, bs=8')
plt.plot(Indices, GeneralizationGapRatio10, color='y', linestyle='solid', label='Gene Gap Ratio, bs=10')
plt.plot(Indices, AccGradientVarianceRatio2, color='b', linestyle='dashed', label='Accu Grad Variance Ratio, bs=2')
plt.plot(Indices, AccGradientVarianceRatio4, color='r', linestyle='dashed', label='Accu Grad Variance Ratio, bs=4')
#plt.plot(Indices, AccGradientVarianceRatio8, color='k', linestyle='dashed', label='Accu Grad Variance Ratio, bs=8')
plt.plot(Indices, AccGradientVarianceRatio10, color='y', linestyle='dashed', label='Accu Grad Variance Ratio, bs=10')
plt.legend()
#plt.yscale('log')
plt.title("Ratio between the Second Moments and the Squared First Moments of SGD (d={}, N={}, 5-sparse beta)".format(str(40),str(40)))
plt.xlabel("Iteration")
plt.ylabel("Ratio")
plt.savefig("Summary")
plt.close()
'''
# Grouped bar chart: for each of the four settings, draw the GD, SGD and
# NoisyGD values side by side and save the figure as 'Bar Chart'.
barWidth = 0.2  # width of one bar; a group is three adjacent bars

# plt.subplots returns a (figure, axes) pair; unpack it so that `fig` is the
# Figure itself rather than the whole tuple.
fig, _ax = plt.subplots(figsize=(10, 8))

# Left edge series sits at 0, 1, 2, ...; each further series is shifted right
# by one bar width so the bars of a group touch without overlapping.
br1 = np.arange(len(GD))
br2 = br1 + barWidth
br3 = br2 + barWidth

# One plt.bar call per algorithm, in the same order and colours as before.
for positions, values, colour, series_label in (
    (br1, GD, 'r', 'GD'),
    (br2, SGD, 'g', 'SGD'),
    (br3, NoisyGD, 'b', 'NoisyGD'),
):
    plt.bar(positions, values, color=colour, width=barWidth,
            edgecolor='grey', label=series_label)

plt.xlabel('Setting', fontweight='bold', fontsize=15)
plt.ylabel('Product', fontweight='bold', fontsize=15)
# Ticks go under the middle (second) bar of each group. The last label used to
# read 'x=-4.bs=2' (period instead of comma), inconsistent with the other three.
plt.xticks(br2, ['x=-3,bs=1', 'x=-4,bs=1', 'x=-3,bs=2', 'x=-4,bs=2'])

# Legend anchored by its lower-left corner at (0.9, 0.95) in axes coordinates.
plt.legend(loc=(0.9, 0.95))
plt.savefig('Bar Chart')
plt.close()

# Lambda values of the SGDwOriginReg sweep (x axis of the test-loss plot).
lambdas = [
    0, 4, 5, 6, 7, 8, 9, 10, 11,
    11.3, 11.5, 11.6, 11.7, 11.8, 11.9,
    12, 12.1, 12.2, 12.3, 12.4, 12.5, 13,
]

# Three recorded loss values per lambda, in the same order as `lambdas`.
# presumably one value per run/seed -- confirm against the experiment logs.
# NOTE(review): the 4th and 5th triples (lambda 6 and 7) are nearly identical.
_origin_reg_loss_terms = [
    (6.039, 6.605, 3.622),
    (4.214, 4.388, 1.666),
    (3.86147, 4.0330, 1.3890),
    (3.2305, 3.522, 0.9477),
    (3.230, 3.522, 0.947),
    (2.9571, 3.319, 0.745),
    (2.672, 3.12, 0.530),
    (2.35, 2.926, 0.3597),
    (1.997, 2.771, 0.260),
    (1.877, 2.704, 0.252),
    (1.7946, 2.655, 0.206),
    (1.7525, 2.63, 0.247),
    (1.7101, 2.603, 0.238),
    (1.673, 2.606, 0.3093),
    (1.62, 2.54, 0.212),
    (1.60, 2.545, 0.376),
    (1.545, 2.58, 1.88),
    (1.486, 2.541, 2.12),
    (1.451, 2.488, 4.13),
    (1.407, 2.454, 5.30),
    (1.4259, 2.4647, 5.319),
    (6.324, 7.37, 5.2779),
]
# Mean of each triple. The explicit left-to-right addition is kept (rather than
# sum()) so the floating-point results match the original expressions exactly.
SGDwOriginRegTestLoss = [(a + b + c) / 3 for a, b, c in _origin_reg_loss_terms]

# Single-point SGDwReg runs: a one-element lambda list paired with a
# one-element list holding the mean of its three loss values.
lambda1 = [4]
SGDwReg1TestLoss = [(1.263 + 2.267 + 0.316) / 3]
lambda2 = [4.5]
SGDwReg2TestLoss = [(1.381 + 2.41 + 0.443) / 3]
lambda3 = [5]
SGDwReg3TestLoss = [(1.37 + 2.414 + 0.379) / 3]
lambda4 = [5.5]
SGDwReg4TestLoss = [(1.2957 + 2.31 + 0.168) / 3]

# Test loss versus lambda for the SGDwOriginReg sweep, with the four
# single-point SGDwReg runs and the GD / SGD values drawn as horizontal
# reference lines. The figure is saved as "TestLosses".
plt.plot(lambdas, SGDwOriginRegTestLoss, color='r', linestyle='solid', label="SGDwOriginReg")

# (lambda list, loss list, colour, legend label) for each single-point run.
# NOTE(review): the second number in each label is not the plotted loss value;
# the labels are kept verbatim -- confirm what that number denotes.
reg_runs = (
    (lambda1, SGDwReg1TestLoss, 'k', '(4,1.375)'),
    (lambda2, SGDwReg2TestLoss, 'teal', '(4.5,1.25)'),
    (lambda3, SGDwReg3TestLoss, 'b', '(5,1.175)'),
    (lambda4, SGDwReg4TestLoss, 'grey', '(5.5,1.125)'),
)

# Hollow marker at each run's (lambda, loss) point; unlabeled, so the legend
# only lists the horizontal lines.
for run_lambdas, run_losses, colour, _unused_label in reg_runs:
    plt.scatter(run_lambdas, run_losses, s=10, facecolors='none', edgecolors=colour)

# Horizontal line at each run's loss. axhline documents `y` as a single float,
# so the scalar is taken out of the one-element list instead of passing the
# list itself.
for _unused_lambdas, run_losses, colour, legend_label in reg_runs:
    plt.axhline(y=run_losses[0], color=colour, linestyle='-', label=legend_label)

# Reference levels labelled GD and SGD (mean of three recorded values each).
plt.axhline(y=(6.1533 + 6.753 + 3.787) / 3, color='y', linestyle='-', label='GD')
plt.axhline(y=(6.039 + 6.605 + 3.622) / 3, color='g', linestyle='-', label='SGD')

plt.yscale('log')
plt.legend()
plt.title("Test Losses of Algorithms with DLN")
plt.xlabel("Lambda")
plt.ylabel("Test Loss")
plt.savefig("TestLosses")
plt.close()

# Disabled code: this bare triple-quoted string is evaluated and discarded, so
# the generalization-gap plot below is never drawn. It reads GDGeneGap,
# SGDGeneGap and SGDwRegGeneGap, none of which is defined in the visible part
# of this file.
'''
plt.plot(lambdas, GDGeneGap, color='b', label="GD")
plt.plot(lambdas, SGDGeneGap, color='y', label="SGD")
plt.plot(lambdas, SGDwRegGeneGap, color='r', label="SGDwReg")
plt.legend()
plt.title("Generalization Gaps of Algorithms with DLN")
plt.xlabel("Lambda")
plt.ylabel("Generalization Gap")
plt.savefig("GeneGaps")
plt.close()
'''
# Print the ratio between two summed loss triples: the numerator uses the same
# three values as SGDwReg4TestLoss, the denominator the same three values as
# the lambda = 12 entry of SGDwOriginRegTestLoss. Both means divide by 3, so
# the ratio of the raw sums equals the ratio of the means.
_reg4_loss_sum = 1.2957 + 2.31 + 0.168
_origin_reg_loss_sum = 1.60 + 2.545 + 0.376
print(_reg4_loss_sum / _origin_reg_loss_sum)